{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "h2q27gKz1H20"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "TUfAcER1oUS6"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Gb7qyhNL1yWt"
      },
      "source": [
        "# Text classification with TensorFlow Lite Model Maker with TensorFlow 2.0"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Fw5Y7snSuG51"
      },
      "source": [
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/examples/blob/master/tensorflow_examples/lite/model_maker/demo/text_classification.ipynb\"\u003e\n",
        "    \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003e\n",
        "    Run in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/examples/blob/master/tensorflow_examples/lite/model_maker/demo/text_classification.ipynb\"\u003e\n",
        "    \u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003e\n",
        "    View source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "sr3q-gvm3cI8"
      },
      "source": [
        "The TensorFlow Lite Model Maker library simplifies the process of adapting and converting a TensorFlow neural-network model to particular input data when deploying this model for on-device ML applications.\n",
        "\n",
        "This notebook shows an end-to-end example that uses the Model Maker library to adapt and convert a commonly used text classification model so that it can classify movie reviews on a mobile device."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "bcLF2PKkSbV3"
      },
      "source": [
        "## Prerequisites\n",
        "\n",
        "To run this example, we first need to install several required packages, including the Model Maker package from the GitHub [repo](https://github.com/tensorflow/examples/tree/master/tensorflow_examples/lite/model_maker)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "qhl8lqVamEty"
      },
      "outputs": [],
      "source": [
        "!pip install git+https://github.com/tensorflow/examples.git#egg=tensorflow-examples[model_maker]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "l6lRhVK9Q_0U"
      },
      "source": [
        "Import the required packages."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "XtxiUeZEiXpt"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import os\n",
        "\n",
        "import tensorflow as tf\n",
        "assert tf.__version__.startswith('2')\n",
        "\n",
        "from tensorflow_examples.lite.model_maker.core.data_util.text_dataloader import TextClassifierDataLoader\n",
        "from tensorflow_examples.lite.model_maker.core.model_export_format import ModelExportFormat\n",
        "from tensorflow_examples.lite.model_maker.core.task.model_spec import AverageWordVecModelSpec\n",
        "from tensorflow_examples.lite.model_maker.core.task.model_spec import BertModelSpec\n",
        "from tensorflow_examples.lite.model_maker.core.task import text_classifier"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "06sWWfvE6I8e"
      },
      "source": [
        "## Simple End-to-End Example"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "BRd13bfetO7B"
      },
      "source": [
        "### Get the data path\n",
        "Let's get some texts to play with this simple end-to-end example."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "R2BSkxWg6Rhx"
      },
      "outputs": [],
      "source": [
        "data_path = tf.keras.utils.get_file(\n",
        "      fname='aclImdb',\n",
        "      origin='http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz',\n",
        "      untar=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "6MSCjPAvs2EQ"
      },
      "source": [
        "You can replace it with your own text folders. To upload data to Colab, use the upload button in the left sidebar, marked with the red rectangle in the image below. Try uploading a zip file and unzipping it. The root file path is the current path.\n",
        "\n",
        "\u003cimg src=\"https://storage.googleapis.com/download.tensorflow.org/models/tflite/screenshots/model_maker_text_classification.png\" alt=\"Upload File\" width=\"800\" hspace=\"100\"\u003e\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "uO5egTlrtWxm"
      },
      "source": [
        "If you prefer not to upload your data to the cloud, you can run the library locally by following the [guide](https://github.com/tensorflow/examples/tree/master/tensorflow_examples/lite/model_maker) on GitHub."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "WlKU3SMX6TnB"
      },
      "source": [
        "### Run the example\n",
        "\n",
        "The example just consists of 6 lines of code as shown below, representing 5 steps of the overall process."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "PBPUIhEjMjTR"
      },
      "source": [
        "0. Choose a `model_spec` that represents a model for the text classifier."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "CtdZ-JDwMimd"
      },
      "outputs": [],
      "source": [
        "model_spec = AverageWordVecModelSpec()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "s5U-A3tw6Y27"
      },
      "source": [
        "1.   Load the train and test data specific to an on-device ML app and preprocess it according to the chosen `model_spec`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "HD5BvzWe6YKa"
      },
      "outputs": [],
      "source": [
        "train_data = TextClassifierDataLoader.from_folder(os.path.join(data_path, 'train'), model_spec=model_spec, class_labels=['pos', 'neg'])\n",
        "test_data = TextClassifierDataLoader.from_folder(os.path.join(data_path, 'test'), model_spec=model_spec, is_training=False, shuffle=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "2uZkLR6N6gDR"
      },
      "source": [
        "2. Customize the TensorFlow model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "kwlYdTcg63xy"
      },
      "outputs": [],
      "source": [
        "model = text_classifier.create(train_data, model_spec=model_spec)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "-BzCHLWJ6h7q"
      },
      "source": [
        "3. Evaluate the model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "8xmnl6Yy7ARn"
      },
      "outputs": [],
      "source": [
        "loss, acc = model.evaluate(test_data)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "CgCDMe0e6jlT"
      },
      "source": [
        "4.  Export to a TensorFlow Lite model.\n",
        "You can download it from the left sidebar, the same place used for uploading, for your own use."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Hm_UULdW7A9T"
      },
      "outputs": [],
      "source": [
        "model.export('movie_review_classifier.tflite', 'text_label.txt', 'vocab.txt')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "rVxaf3x_7OfB"
      },
      "source": [
        "After these 5 simple steps, we can use the TensorFlow Lite model file and label file in on-device applications such as the [text classification](https://github.com/tensorflow/examples/tree/master/lite/examples/text_classification) reference app."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "l65ctmtW7_FF"
      },
      "source": [
        "## Detailed Process\n",
        "\n",
        "In the above, we tried the simple end-to-end example. The following walks through the example step by step to show more detail."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "kJ_B8fMDOhMR"
      },
      "source": [
        "### Step 0: Choose a model_spec that represents a model for the text classifier\n",
        "\n",
        "Each `model_spec` object represents a specific model for the text classifier. Currently, we support an averaging word embedding model and a BERT-base model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "vEAWuZQ1PFiX"
      },
      "outputs": [],
      "source": [
        "model_spec = AverageWordVecModelSpec()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ygEncJxtl-nQ"
      },
      "source": [
        "### Step 1: Load Input Data Specific to an On-device ML App\n",
        "\n",
        "The IMDB dataset contains 25000 movie reviews for training and 25000 movie reviews for testing from the [Internet Movie Database](https://www.imdb.com/). The dataset has two classes: positive and negative movie reviews.\n",
        "\n",
        "Download the archive version of the dataset and untar it.\n",
        "\n",
        "The IMDB dataset has the following directory structure:\n",
        "\n",
        "\u003cpre\u003e\n",
        "\u003cb\u003eaclImdb\u003c/b\u003e\n",
        "|__ \u003cb\u003etrain\u003c/b\u003e\n",
        "    |______ \u003cb\u003epos\u003c/b\u003e: [1962_10.txt, 2499_10.txt, ...]\n",
        "    |______ \u003cb\u003eneg\u003c/b\u003e: [104_3.txt, 109_2.txt, ...]\n",
        "    |______ unsup: [12099_0.txt, 1424_0.txt, ...]\n",
        "|__ \u003cb\u003etest\u003c/b\u003e\n",
        "    |______ \u003cb\u003epos\u003c/b\u003e: [1384_9.txt, 191_9.txt, ...]\n",
        "    |______ \u003cb\u003eneg\u003c/b\u003e: [1629_1.txt, 21_1.txt]\n",
        "\n",
        "\u003c/pre\u003e\n",
        "\n",
        "Note that the text files under the `train/unsup` folder are unlabeled documents intended for unsupervised learning; this tutorial ignores them.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "7tOfUr2KlgpU"
      },
      "outputs": [],
      "source": [
        "data_path = tf.keras.utils.get_file(\n",
        "      fname='aclImdb',\n",
        "      origin='http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz',\n",
        "      untar=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "E051HBUM5owi"
      },
      "source": [
        "Use `TextClassifierDataLoader` to load data.\n",
        "\n",
        "The `from_folder()` method loads data from a folder. It assumes that text files of the same class are in the same subdirectory and that the subfolder name is the class name. Each text file contains one movie review sample.\n",
        "\n",
        "The `class_labels` parameter specifies which subfolders to load. For the `train` folder, it is used to skip the `unsup` subfolder.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "I_fOlZsklmlL"
      },
      "outputs": [],
      "source": [
        "train_data = TextClassifierDataLoader.from_folder(os.path.join(data_path, 'train'), model_spec=model_spec, class_labels=['pos', 'neg'])\n",
        "test_data = TextClassifierDataLoader.from_folder(os.path.join(data_path, 'test'), model_spec=model_spec, is_training=False, shuffle=False)\n",
        "train_data, validation_data = train_data.split(0.9)"
      ]
    },
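    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The folder convention described above can be illustrated with a minimal, standard-library sketch. It is not the `TextClassifierDataLoader` implementation: `load_text_folder` and the tiny temporary folder are hypothetical stand-ins, and the real loader additionally preprocesses the text according to `model_spec`.\n",
        "\n",
        "```python\n",
        "import pathlib\n",
        "import tempfile\n",
        "\n",
        "\n",
        "def load_text_folder(root, class_labels=None):\n",
        "    # Hypothetical helper: one subfolder per class, one sample per .txt file.\n",
        "    root = pathlib.Path(root)\n",
        "    # Default to every subfolder; pass class_labels to keep only some of them.\n",
        "    if class_labels is None:\n",
        "        class_labels = sorted(p.name for p in root.iterdir() if p.is_dir())\n",
        "    texts, labels = [], []\n",
        "    for index, name in enumerate(class_labels):\n",
        "        for path in sorted((root / name).glob('*.txt')):\n",
        "            texts.append(path.read_text(encoding='utf-8'))\n",
        "            labels.append(index)\n",
        "    return texts, labels, list(class_labels)\n",
        "\n",
        "\n",
        "# Tiny stand-in for aclImdb/train, including an unsup folder to skip.\n",
        "with tempfile.TemporaryDirectory() as tmp:\n",
        "    samples = [('pos', 'great film'), ('neg', 'dull plot'), ('unsup', 'no label')]\n",
        "    for name, review in samples:\n",
        "        folder = pathlib.Path(tmp, name)\n",
        "        folder.mkdir()\n",
        "        (folder / '0_0.txt').write_text(review, encoding='utf-8')\n",
        "    texts, labels, names = load_text_folder(tmp, class_labels=['pos', 'neg'])\n",
        "\n",
        "print(texts, labels, names)\n",
        "```\n",
        "\n",
        "Passing `class_labels=['pos', 'neg']` is what keeps the `unsup` review out of the result. In this sketch, label ids simply follow the order of `class_labels`."
      ]
    },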
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "AWuoensX4vDA"
      },
      "source": [
        "### Step 2: Customize the TensorFlow Model\n",
        "\n",
        "Create a custom text classifier model based on the loaded data. Currently, we support an averaging word embedding model and a BERT-base model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "TvYSUuJY3QxR"
      },
      "outputs": [],
      "source": [
        "model = text_classifier.create(train_data, model_spec=model_spec, validation_data=validation_data)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "0JKI-pNc8idH"
      },
      "source": [
        "Have a look at the detailed model structure."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "gd7Hs8TF8n3H"
      },
      "outputs": [],
      "source": [
        "model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "LP5FPk_tOxoZ"
      },
      "source": [
        "### Step 3: Evaluate the Customized Model\n",
        "\n",
        "Evaluate the model to get its loss and accuracy.\n",
        "\n",
        "Here the loss and accuracy are computed on `test_data`. If no data is given, the results are evaluated on the data that was split off in the `create` method."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "A8c2ZQ0J3Riy"
      },
      "outputs": [],
      "source": [
        "loss, acc = model.evaluate(test_data)"
      ]
    },
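    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "To make the two returned numbers concrete, here is a small sketch of how a loss and an accuracy can be computed from predicted class probabilities. It assumes a cross-entropy loss and top-1 accuracy; `loss_and_accuracy` and the sample probabilities are hypothetical, and the library may compute its metrics differently.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "\n",
        "def loss_and_accuracy(probabilities, labels):\n",
        "    # Mean negative log-probability assigned to the true class.\n",
        "    loss = -sum(math.log(p[y]) for p, y in zip(probabilities, labels)) / len(labels)\n",
        "    # A prediction is correct when the most probable class is the true one.\n",
        "    correct = sum(p.index(max(p)) == y for p, y in zip(probabilities, labels))\n",
        "    return loss, correct / len(labels)\n",
        "\n",
        "\n",
        "# Three made-up predictions over two classes.\n",
        "sample_probabilities = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]]\n",
        "sample_labels = [0, 1, 1]\n",
        "sample_loss, sample_acc = loss_and_accuracy(sample_probabilities, sample_labels)\n",
        "print('loss = %.4f, accuracy = %.4f' % (sample_loss, sample_acc))\n",
        "```\n",
        "\n",
        "The third sample is misclassified, so it lowers the accuracy and also contributes the largest term to the loss."
      ]
    },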
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "aeHoGAceO2xV"
      },
      "source": [
        "### Step 4: Export to TensorFlow Lite Model\n",
        "\n",
        "Convert the existing model to the TensorFlow Lite model format so that it can later be used in an on-device ML application. At the same time, save the text labels to a label file and the vocabulary to a vocab file."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Im6wA9lK3TQB"
      },
      "outputs": [],
      "source": [
        "model.export('movie_review_classifier.tflite', 'text_label.txt', 'vocab.txt')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "w12kvDdHJIGH"
      },
      "source": [
        "The TensorFlow Lite model file and label file can be used in the [text classification](https://github.com/tensorflow/examples/tree/master/lite/examples/text_classification) reference app.\n",
        "\n",
        "Specifically, add `movie_review_classifier.tflite`, `text_label.txt` and `vocab.txt` to the [assets directory](https://github.com/tensorflow/examples/tree/master/lite/examples/text_classification/android/app/src/main/assets), then change the filenames in the [code](https://github.com/tensorflow/examples/blob/master/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/TextClassificationClient.java#L43)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "HZKYthlVrTos"
      },
      "source": [
        "Here, we also demonstrate how to use the above files to run and evaluate the TensorFlow Lite model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "ochbq95ZrVFX"
      },
      "outputs": [],
      "source": [
        "# Read TensorFlow Lite model from TensorFlow Lite file.\n",
        "with tf.io.gfile.GFile('movie_review_classifier.tflite', 'rb') as f:\n",
        "  model_content = f.read()\n",
        "\n",
        "# Read label names from label file.\n",
        "with tf.io.gfile.GFile('text_label.txt', 'r') as f:\n",
        "  label_names = f.read().split('\\n')\n",
        "\n",
        "# Initialize TensorFlow Lite interpreter.\n",
        "interpreter = tf.lite.Interpreter(model_content=model_content)\n",
        "interpreter.allocate_tensors()\n",
        "input_index = interpreter.get_input_details()[0]['index']\n",
        "output = interpreter.tensor(interpreter.get_output_details()[0][\"index\"])\n",
        "\n",
        "# Run predictions on each test data and calculate accuracy.\n",
        "accurate_count = 0\n",
        "for text, label in test_data.dataset:\n",
        "    # Add batch dimension and convert to float32 to match with the model's input\n",
        "    # data format.\n",
        "    text = tf.expand_dims(text, 0)\n",
        "    text = tf.cast(text, tf.float32)\n",
        "\n",
        "    # Run inference.\n",
        "    interpreter.set_tensor(input_index, text)\n",
        "    interpreter.invoke()\n",
        "\n",
        "    # Post-processing: remove batch dimension and find the label with highest\n",
        "    # probability.\n",
        "    predict_label = np.argmax(output()[0])\n",
        "    # Get label name with label index.\n",
        "    predict_label_name = label_names[predict_label]\n",
        "    accurate_count += (predict_label == label.numpy())\n",
        "\n",
        "accuracy = accurate_count * 1.0 / test_data.size\n",
        "print('TensorFlow Lite model accuracy = %.4f' % accuracy)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "KLKmboKFtgc2"
      },
      "source": [
        "Note that preprocessing for inference must be the same as for training. Currently, preprocessing splits the text into tokens by '\\W', encodes the tokens to ids, then pads the text with `pad_id` to the length of `seq_len`."
      ]
    },
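    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The three preprocessing steps above can be sketched in plain Python. This is only an illustration under assumptions, not the library's code: `preprocess`, `PAD_ID`, `UNKNOWN_ID`, the toy vocabulary, the lowercasing, and the truncation of long inputs are all hypothetical choices made for this example.\n",
        "\n",
        "```python\n",
        "import re\n",
        "\n",
        "PAD_ID = 0      # hypothetical id used for padding\n",
        "UNKNOWN_ID = 1  # hypothetical id for out-of-vocabulary tokens\n",
        "\n",
        "\n",
        "def preprocess(text, vocab, seq_len):\n",
        "    # Split on non-word characters and drop the empty tokens this produces.\n",
        "    tokens = [t for t in re.split(r'\\W', text.lower()) if t]\n",
        "    # Encode each token to its integer id.\n",
        "    ids = [vocab.get(t, UNKNOWN_ID) for t in tokens]\n",
        "    # Truncate long inputs, then pad short ones up to seq_len.\n",
        "    ids = ids[:seq_len]\n",
        "    return ids + [PAD_ID] * (seq_len - len(ids))\n",
        "\n",
        "\n",
        "toy_vocab = {'a': 2, 'great': 3, 'movie': 4}\n",
        "print(preprocess('A great, great movie!', toy_vocab, seq_len=6))\n",
        "```\n",
        "\n",
        "Whatever the exact rules are, the point is that the on-device app has to apply the same tokenization, vocabulary, and padding that were used during training."
      ]
    },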
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "EoWiA_zX8rxE"
      },
      "source": [
        "## Advanced Usage\n",
        "\n",
        "The `create` function is the critical part of this library. Its `model_spec` parameter defines the specification of the model; currently `AverageWordVecModelSpec` and `BertModelSpec` are supported. For `AverageWordVecModelSpec`, the `create` function performs the following steps:\n",
        "\n",
        "1.   Tokenize the text and select the top `num_words` most frequent words to generate the vocabulary. The default value of `num_words` in the `AverageWordVecModelSpec` object is `10000`.\n",
        "2.   Encode the text string tokens to int ids.\n",
        "3.   Create the text classifier model. This model averages the word embeddings of the text with ReLU activation, then uses a softmax dense layer for classification. For the [Embedding layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding), the input dimension is the size of the vocabulary, the output dimension is the `wordvec_dim` variable of the `AverageWordVecModelSpec` object (default `16`), and the input length is its `seq_len` variable (default `256`).\n",
        "4.   Train the classifier model. The default number of epochs is `2` and the default batch size is `32`.\n",
        "\n",
        "In this section, we describe several advanced topics, including adjusting the model and changing the training hyperparameters.\n"
      ]
    },
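    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "Step 1 of the list above, building a vocabulary from the `num_words` most frequent tokens, can be sketched as follows. This is a hedged illustration rather than the library's implementation: `build_vocab`, the reserved `[PAD]` and `[UNK]` tokens, and the choice not to count them toward `num_words` are assumptions made for this example.\n",
        "\n",
        "```python\n",
        "import collections\n",
        "import re\n",
        "\n",
        "\n",
        "def build_vocab(texts, num_words, reserved=('[PAD]', '[UNK]')):\n",
        "    # Count every token across the corpus.\n",
        "    counts = collections.Counter(\n",
        "        token for text in texts for token in re.split(r'\\W', text.lower()) if token)\n",
        "    # Reserved tokens take the first ids; the most frequent words follow.\n",
        "    vocab = {token: index for index, token in enumerate(reserved)}\n",
        "    for token, _ in counts.most_common(num_words):\n",
        "        vocab[token] = len(vocab)\n",
        "    return vocab\n",
        "\n",
        "\n",
        "toy_reviews = ['good good movie', 'bad movie', 'good plot']\n",
        "print(build_vocab(toy_reviews, num_words=2))\n",
        "```\n",
        "\n",
        "Words that fall outside the top `num_words` get no id of their own, which is why an encoder built on such a vocabulary needs a fallback id for unknown tokens."
      ]
    },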
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "mwtiksguDfhl"
      },
      "source": [
        "## Adjust the model\n",
        "\n",
        "We can adjust the model architecture through parameters such as `wordvec_dim` and `seq_len` of the `AverageWordVecModelSpec` class.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "cAOd5_bzH9AQ"
      },
      "source": [
        "*   `wordvec_dim`: Dimension of the word embedding.\n",
        "*   `seq_len`: Length of the sequence.\n",
        "\n",
        "For example, we can train with a larger `wordvec_dim`. When we change the model, we first need to construct a new `model_spec`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "e9WBN0UTQoMN"
      },
      "outputs": [],
      "source": [
        "new_model_spec = AverageWordVecModelSpec(wordvec_dim=32)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "6LSTdghTP0Cv"
      },
      "source": [
        "Second, we load and preprocess the data again according to the new `model_spec`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "DVZurFBORG3J"
      },
      "outputs": [],
      "source": [
        "new_train_data = TextClassifierDataLoader.from_folder(os.path.join(data_path, 'train'), model_spec=new_model_spec, class_labels=['pos', 'neg'])\n",
        "new_train_data, new_validation_data = new_train_data.split(0.9)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "tD7QVVHeRZoM"
      },
      "source": [
        "Finally, we could train the new model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "PzpV246_JGEu"
      },
      "outputs": [],
      "source": [
        "model = text_classifier.create(new_train_data, model_spec=new_model_spec, validation_data=new_validation_data)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "LvQuy7RSDir3"
      },
      "source": [
        "## Change the training hyperparameters\n",
        "We could also change the training hyperparameters like `epochs` and `batch_size` that could affect the model accuracy. For instance,\n",
        "\n",
        "*   `epochs`: more epochs could achieve better accuracy, but may lead to overfitting.\n",
        "*   `batch_size`: number of samples to use in one training step.\n",
        "\n",
        "For example, we could train with more epochs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "rnWFaYZBG6NW"
      },
      "outputs": [],
      "source": [
        "model = text_classifier.create(train_data, model_spec=model_spec, validation_data=validation_data, epochs=5)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "nUaKQZBQHBQR"
      },
      "source": [
        "Evaluate the newly retrained model with 5 training epochs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "BMPi1xflHDSY"
      },
      "outputs": [],
      "source": [
        "loss, accuracy = model.evaluate(test_data)"
      ]
    },
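    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "To see how `epochs` and `batch_size` from this section translate into the number of training steps, here is a small worked calculation. `training_steps` is a hypothetical helper, and the sketch assumes that the last partial batch of each epoch is kept rather than dropped, which may differ from what the library does.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "\n",
        "def training_steps(num_examples, batch_size, epochs):\n",
        "    # Assumes the final partial batch of each epoch is kept, not dropped.\n",
        "    steps_per_epoch = math.ceil(num_examples / batch_size)\n",
        "    return steps_per_epoch, steps_per_epoch * epochs\n",
        "\n",
        "\n",
        "# 90% of the 25000 training reviews remain after the validation split.\n",
        "num_train = 25000 * 9 // 10\n",
        "print(training_steps(num_train, batch_size=32, epochs=2))\n",
        "print(training_steps(num_train, batch_size=32, epochs=5))\n",
        "```\n",
        "\n",
        "Under these assumptions, raising `epochs` from 2 to 5 multiplies the total number of steps, and roughly the training time, by 2.5, while a larger `batch_size` reduces the number of steps per epoch."
      ]
    },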
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Eq6B9lKMfhS6"
      },
      "source": [
        "## Change the Model\n",
        "\n",
        "We can change the model by changing the `model_spec`. The following shows how to switch to the BERT-base model.\n",
        "\n",
        "First, change `model_spec` to `BertModelSpec`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "QfFCWrwyggrT"
      },
      "outputs": [],
      "source": [
        "model_spec = BertModelSpec()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "L2d7yycrgu6L"
      },
      "source": [
        "The remaining steps stay the same.\n",
        "\n",
        "Load the data and preprocess it according to `model_spec`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "6GQXQO54iyyE"
      },
      "outputs": [],
      "source": [
        "train_data = TextClassifierDataLoader.from_folder(os.path.join(data_path, 'train'), model_spec=model_spec, class_labels=['pos', 'neg'])\n",
        "test_data = TextClassifierDataLoader.from_folder(os.path.join(data_path, 'test'), model_spec=model_spec, is_training=False, shuffle=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ZTMqpDXCi11Q"
      },
      "source": [
        "Then retrain the model. Note that retraining the BERT model can take a long time, so we set `epochs` to 1 just to demonstrate it."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "c991Bdkgi1Bf"
      },
      "outputs": [],
      "source": [
        "model = text_classifier.create(train_data, model_spec=model_spec, epochs=1)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "text_classification.ipynb",
      "private_outputs": true,
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
